{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Amplitude Perturbation Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "In this tutorial, we perturb the spectral amplitudes of the input signals and correlate those perturbations with the resulting changes in the network's predictions. This shows which frequencies and channels a trained convolutional network relies on. For more background, see [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051), Section A.5.2.\n",
    "\n",
    "First we will do some cross-subject decoding, again using the [Physiobank EEG Motor Movement/Imagery Dataset](https://www.physionet.org/physiobank/database/eegmmidb/), this time to decode imagined left hand vs. imagined right hand movement.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "This tutorial might be very slow if you are not using a GPU.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import importlib\n",
    "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
    "log = logging.getLogger()\n",
    "log.setLevel('INFO')\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.INFO, stream=sys.stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "import numpy as np\n",
    "from mne.io import concatenate_raws\n",
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "# Subjects 1-50 are used for training\n",
    "physionet_paths = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12])\n",
    "                   for sub_id in range(1, 51)]\n",
    "physionet_paths = np.concatenate(physionet_paths)\n",
    "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                       exclude='bads')\n",
    "\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Cut epochs from 1 s to 4.1 s relative to each event onset.\n",
    "# In runs 4, 8 and 12, event code 2 marks imagined left hand movement\n",
    "# and event code 3 marks imagined right hand movement.\n",
    "epoched = mne.Epochs(raw, events, dict(left=2, right=3), tmin=1, tmax=4.1, proj=False,\n",
    "                     picks=picks, baseline=None, preload=True)\n",
    "\n",
    "# Subjects 51-55 are used for validation\n",
    "physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12])\n",
    "                         for sub_id in range(51, 56)]\n",
    "physionet_paths_valid = np.concatenate(physionet_paths_valid)\n",
    "parts_valid = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n",
    "               for path in physionet_paths_valid]\n",
    "raw_valid = concatenate_raws(parts_valid)\n",
    "\n",
    "picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                             exclude='bads')\n",
    "\n",
    "events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Same epoching as for the training set\n",
    "epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left=2, right=3), tmin=1, tmax=4.1,\n",
    "                           proj=False, picks=picks_valid, baseline=None, preload=True)\n",
    "\n",
    "# Convert from volt to microvolt and map event codes 2, 3 to class labels 0, 1\n",
    "train_X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "train_y = (epoched.events[:, 2] - 2).astype(np.int64)\n",
    "valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)\n",
    "valid_y = (epoched_valid.events[:, 2] - 2).astype(np.int64)\n",
    "train_set = SignalAndTarget(train_X, y=train_y)\n",
    "valid_set = SignalAndTarget(valid_X, y=valid_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the deep ConvNet from [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051) (Section 2.4.2).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from braindecode.models.deep4 import Deep4Net\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "from braindecode.models.util import to_dense_prediction_model\n",
    "\n",
    "# Use the GPU if one is available; set cuda = False to force running on the CPU\n",
    "cuda = torch.cuda.is_available()\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "\n",
    "# This will determine how many crops are processed in parallel\n",
    "input_time_length = 450\n",
    "# final_conv_length determines the size of the receptive field of the ConvNet\n",
    "model = Deep4Net(in_chans=64, n_classes=2, input_time_length=input_time_length,\n",
    "                 filter_length_3=5, filter_length_4=5,\n",
    "                 pool_time_stride=2,\n",
    "                 stride_before_pool=True,\n",
    "                 final_conv_length=1)\n",
    "if cuda:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.optimizers import AdamW\n",
    "import torch.nn.functional as F\n",
    "optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n",
    "model.compile(loss=F.nll_loss, optimizer=optimizer,  iterator_seed=1, cropped=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_time_length = 450\n",
    "model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',\n",
    "          input_time_length=input_time_length,\n",
    "         validation_data=(valid_set.X, valid_set.y),)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute correlation: amplitude perturbation - prediction change"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First collect all batches and concatenate them into one array of examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
    "from braindecode.torch_ext.util import np_to_var\n",
    "\n",
    "# Pass a dummy batch through the network to find out how many\n",
    "# predictions it produces per input crop\n",
    "test_input = np_to_var(np.ones((2, 64, input_time_length, 1), dtype=np.float32))\n",
    "if cuda:\n",
    "    test_input = test_input.cuda()\n",
    "out = model.network(test_input)\n",
    "n_preds_per_input = out.cpu().data.numpy().shape[2]\n",
    "\n",
    "# Iterator that cuts the trials into crops of input_time_length samples\n",
    "iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,\n",
    "                                   n_preds_per_input=n_preds_per_input)\n",
    "\n",
    "train_batches = list(iterator.get_batches(train_set, shuffle=False))\n",
    "train_X_batches = np.concatenate(list(zip(*train_batches))[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create a prediction function that wraps the model and returns its predictions as NumPy arrays. We use the predictions before the softmax, so we build a new module containing all layers of the trained network up to, but not including, the softmax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from braindecode.torch_ext.util import var_to_np\n",
    "import torch as th\n",
    "\n",
    "# Copy all layers up to, but not including, the softmax\n",
    "new_model = nn.Sequential()\n",
    "for name, module in model.network.named_children():\n",
    "    if name == 'softmax':\n",
    "        break\n",
    "    new_model.add_module(name, module)\n",
    "\n",
    "new_model.eval()\n",
    "\n",
    "def pred_fn(x):\n",
    "    # Mean pre-softmax output over all predictions within each crop\n",
    "    x_var = np_to_var(x)\n",
    "    if cuda:\n",
    "        x_var = x_var.cuda()\n",
    "    return var_to_np(th.mean(new_model(x_var)[:, :, :, 0], dim=2, keepdim=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.visualization.perturbation import compute_amplitude_prediction_correlations\n",
    "\n",
    "\n",
    "amp_pred_corrs = compute_amplitude_prediction_correlations(pred_fn, train_X_batches, n_iterations=12,\n",
    "                                         batch_size=30)"
   ]
  },
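  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind the correlation computed above can be sketched in plain NumPy. This is a simplified illustration, not the Braindecode implementation: the function name `perturbation_corr_sketch`, the toy predictor `toy_pred_fn`, the unit noise scale and the clipping of negative amplitudes are all assumptions made for this sketch. It adds Gaussian noise to the Fourier amplitudes while leaving the phases untouched, transforms back to the time domain, and correlates the amplitude perturbation with the change in the prediction across examples.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def perturbation_corr_sketch(pred_fn, X, n_iterations=5, seed=0):\n",
    "    # X: (n_trials, n_chans, n_times); pred_fn maps X to (n_trials, n_outputs).\n",
    "    rng = np.random.RandomState(seed)\n",
    "    fft = np.fft.rfft(X, axis=2)\n",
    "    amps, phases = np.abs(fft), np.angle(fft)\n",
    "    base_preds = pred_fn(X)\n",
    "    all_corrs = []\n",
    "    for _ in range(n_iterations):\n",
    "        # Add Gaussian noise to the amplitudes only (assumed noise scale: 1).\n",
    "        new_amps = np.maximum(amps + rng.randn(*amps.shape), 0)\n",
    "        perturbation = new_amps - amps\n",
    "        X_new = np.fft.irfft(new_amps * np.exp(1j * phases), n=X.shape[2], axis=2)\n",
    "        pred_diff = pred_fn(X_new) - base_preds\n",
    "        # Pearson correlation across trials for every (channel, frequency, output).\n",
    "        p = perturbation - perturbation.mean(axis=0, keepdims=True)\n",
    "        d = pred_diff - pred_diff.mean(axis=0, keepdims=True)\n",
    "        cov = np.einsum('tcf,to->cfo', p, d) / len(X)\n",
    "        std = p.std(axis=0)[:, :, None] * d.std(axis=0)[None, None, :]\n",
    "        all_corrs.append(cov / std)\n",
    "    # Average the correlations over the perturbation iterations.\n",
    "    return np.mean(all_corrs, axis=0)\n",
    "\n",
    "\n",
    "def toy_pred_fn(X):\n",
    "    # Hypothetical predictor that only looks at frequency bin 5 of channel 0;\n",
    "    # the second output is the negated first output.\n",
    "    amp = np.abs(np.fft.rfft(X, axis=2))[:, 0, 5]\n",
    "    return np.stack([amp, -amp], axis=1)\n",
    "\n",
    "\n",
    "X_toy = np.random.RandomState(1).randn(200, 3, 64)\n",
    "toy_corrs = perturbation_corr_sketch(toy_pred_fn, X_toy)\n",
    "print(toy_corrs.shape)  # (n_chans, n_freqs, n_outputs)\n",
    "```\n",
    "\n",
    "For the toy predictor, the correlation at channel 0 and frequency bin 5 is strongly positive for the first output and strongly negative for the second, while the other channels and frequencies stay close to zero. The array returned by this sketch has the same axis order (channels, frequencies, outputs) that the following cells assume for `amp_pred_corrs`."
   ]
  },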
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot correlations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pick one frequency range and average the correlations within it to make a scalp plot.\n",
    "Here we use the alpha frequency range (7-14 Hz)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "amp_pred_corrs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "fs = epoched.info['sfreq']\n",
    "freqs = np.fft.rfftfreq(train_X_batches.shape[2], d=1.0/fs)\n",
    "start_freq = 7\n",
    "stop_freq = 14\n",
    "\n",
    "i_start = np.searchsorted(freqs,start_freq)\n",
    "i_stop = np.searchsorted(freqs, stop_freq) + 1\n",
    "\n",
    "freq_corr = np.mean(amp_pred_corrs[:,i_start:i_stop], axis=1)"
   ]
  },
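  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the band selection above concrete, here is a small standalone sketch of how `np.fft.rfftfreq` and `np.searchsorted` turn a frequency range into array indices. The sampling rate of 160 Hz is an assumption for this example (in the notebook it is read from `epoched.info['sfreq']`); the crop length of 450 samples matches `input_time_length`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "fs = 160.0     # assumed sampling rate in Hz\n",
    "n_times = 450  # crop length in samples\n",
    "\n",
    "# Frequencies of the real-valued FFT bins, spaced fs / n_times apart.\n",
    "freqs = np.fft.rfftfreq(n_times, d=1.0 / fs)\n",
    "\n",
    "# Index of the first bin at or above 7 Hz.\n",
    "i_start = np.searchsorted(freqs, 7)\n",
    "# One past the first bin at or above 14 Hz, so that bin is included in the slice.\n",
    "i_stop = np.searchsorted(freqs, 14) + 1\n",
    "\n",
    "band = freqs[i_start:i_stop]\n",
    "print(len(freqs), freqs[1])\n",
    "print(i_start, i_stop, band[0], band[-1])\n",
    "```\n",
    "\n",
    "Under these assumptions the bins are about 0.36 Hz apart, so the selected slice runs from roughly 7.1 Hz to roughly 14.2 Hz. Because of the `+ 1`, the first bin at or above the stop frequency is part of the average."
   ]
  },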
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now get approximate positions of the channels in the 10-20 system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datasets.sensor_positions import get_channelpos, CHANNEL_10_20_APPROX\n",
    "\n",
    "ch_names = [s.strip('.') for s in epoched.ch_names]\n",
    "positions = [get_channelpos(name, CHANNEL_10_20_APPROX) for name in ch_names]\n",
    "positions = np.array(positions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot with MNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "%matplotlib inline\n",
    "max_abs_val = np.max(np.abs(freq_corr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2)\n",
    "class_names = ['Left Hand', 'Right Hand']\n",
    "for i_class in range(2):\n",
    "    ax = axes[i_class]\n",
    "    mne.viz.plot_topomap(freq_corr[:, i_class], positions,\n",
    "                         vmin=-max_abs_val, vmax=max_abs_val, contours=0,\n",
    "                         cmap=cm.coolwarm, axes=ax, show=False)\n",
    "    ax.set_title(class_names[i_class])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot with Braindecode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.visualization.plot import ax_scalp\n",
    "\n",
    "fig, axes = plt.subplots(1, 2)\n",
    "class_names = ['Left Hand', 'Right Hand']\n",
    "for i_class in range(2):\n",
    "    ax = axes[i_class]\n",
    "    ax_scalp(freq_corr[:,i_class], ch_names, chan_pos_list=CHANNEL_10_20_APPROX, cmap=cm.coolwarm,\n",
    "            vmin=-max_abs_val, vmax=max_abs_val, ax=ax)\n",
    "    ax.set_title(class_names[i_class])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These plots show that the ConvNet learned to use the lateralized response in the alpha band. Note that the positive correlations on the left side for the left hand do not imply that alpha activity increases for the left hand in the data. See [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051), Result 12, for some notes on interpretability."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset references\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n",
    "\n",
    "     Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n",
    "\n",
    "    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
